{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
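    "# Uncomment to download the MNLI shards (translated-0.json ... translated-4.json);\n",
    "# the MNLI conversion cell further down reads them via glob('translated-*.json').\n",
    "# for i in range(5):\n",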
    "#     url = f'https://f000.backblazeb2.com/file/malay-dataset/text-similarity/mnli/translated-{i}.json'\n",
    "#     os.system(f'wget {url}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the seven SNLI shards (part1.json ... part7.json) into the working directory.\n",
    "for i in range(7):\n",
    "    url = f'https://f000.backblazeb2.com/file/malay-dataset/text-similarity/snli/part{i + 1}.json'\n",
    "    # os.system returns the command's exit status; a non-zero value means wget failed,\n",
    "    # so stop here instead of silently carrying on with a missing shard.\n",
    "    status = os.system(f'wget {url}')\n",
    "    if status != 0:\n",
    "        raise RuntimeError(f'wget failed with status {status} for {url}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def cleaning(string):\n",
    "    string = string.replace('\\n', ' ')\n",
    "    string = re.sub(r'[ ]+', ' ', string).strip()\n",
    "    return string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "part2.json\n",
      "part2.json 49998\n",
      "part7.json\n",
      "part7.json 30934\n",
      "part1.json\n",
      "part1.json 50000\n",
      "part4.json\n",
      "part4.json 50000\n",
      "part6.json\n",
      "part6.json 100000\n",
      "part3.json\n",
      "part3.json 50000\n",
      "part5.json\n",
      "part5.json 49998\n"
     ]
    }
   ],
   "source": [
    "from glob import glob\n",
    "import tensorflow as tf\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "# Map the English SNLI labels to their Malay equivalents; a label missing from\n",
    "# this mapping is written to the TSV unchanged.\n",
    "labels = {'contradiction': 'percanggahan', 'entailment': 'berkait'}\n",
    "\n",
    "# Sort the shard names so the rows of the output file come out in a reproducible order.\n",
    "files = sorted(glob('part*.json'))\n",
    "\n",
    "filename = 'snli.tsv'\n",
    "with tf.io.gfile.GFile(filename, 'w') as outfile:\n",
    "    for file in files:\n",
    "        print(file)\n",
    "        with open(file) as fopen:\n",
    "            data = json.load(fopen)\n",
    "        print(file, len(data))\n",
    "\n",
    "        for row in data:\n",
    "            # A usable record is a (label, 'sentence1 <> sentence2') pair; skip anything else.\n",
    "            if len(row) != 2:\n",
    "                continue\n",
    "\n",
    "            label = labels.get(row[0], row[0])\n",
    "            splitted = row[1].split(' <> ')\n",
    "            if len(splitted) != 2:\n",
    "                continue\n",
    "            # Clean BOTH sentences: a newline left in the second one would split the\n",
    "            # record across two lines of the TSV.\n",
    "            q = f'ayat1: {cleaning(splitted[0])} ayat2: {cleaning(splitted[1])}'\n",
    "            outfile.write('%s\\t%s\\n' % (q, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Load the JSON content of 'translated-4.json' into `data`.\n",
    "with open('translated-4.json') as shard:\n",
    "    data = json.loads(shard.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "translated-1.json\n",
      "translated-1.json 99874\n",
      "translated-4.json\n",
      "translated-4.json 9994\n",
      "translated-3.json\n",
      "translated-3.json 92581\n",
      "translated-0.json\n",
      "translated-0.json 99871\n",
      "translated-2.json\n",
      "translated-2.json 99890\n"
     ]
    }
   ],
   "source": [
    "# Declare the NLI label mapping in this cell so it does not depend on which other\n",
    "# cell last assigned the global `labels` (the Quora cell rebinds that name to an\n",
    "# integer-keyed mapping). A label missing from the mapping is written unchanged.\n",
    "labels = {'contradiction': 'percanggahan', 'entailment': 'berkait'}\n",
    "\n",
    "# Sort the shard names so the rows of the output file come out in a reproducible order.\n",
    "files = sorted(glob('translated-*.json'))\n",
    "\n",
    "filename = 'mnli.tsv'\n",
    "with tf.io.gfile.GFile(filename, 'w') as outfile:\n",
    "    for file in files:\n",
    "        print(file)\n",
    "        with open(file) as fopen:\n",
    "            data = json.load(fopen)\n",
    "        print(file, len(data))\n",
    "\n",
    "        for row in data:\n",
    "            # Usable records have three fields: the label is the second one and the\n",
    "            # ' <> '-joined text is the third.\n",
    "            if len(row) != 3:\n",
    "                continue\n",
    "\n",
    "            label = labels.get(row[1], row[1])\n",
    "            splitted = row[2].split(' <> ')\n",
    "            # The text must split into exactly three parts; only the first two are used.\n",
    "            if len(splitted) != 3:\n",
    "                continue\n",
    "            # Clean BOTH sentences: a newline left in the second one would split the\n",
    "            # record across two lines of the TSV.\n",
    "            q = f'ayat1: {cleaning(splitted[0])} ayat2: {cleaning(splitted[1])}'\n",
    "            outfile.write('%s\\t%s\\n' % (q, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quora question-pair shards to download into the working directory.\n",
    "urls = [\n",
    "'https://f000.backblazeb2.com/file/malay-dataset/text-similarity/quora/0-100k.json',\n",
    "'https://f000.backblazeb2.com/file/malay-dataset/text-similarity/quora/100k-200k.json',\n",
    "'https://f000.backblazeb2.com/file/malay-dataset/text-similarity/quora/200k-300k.json',\n",
    "'https://f000.backblazeb2.com/file/malay-dataset/text-similarity/quora/300k-400k.json',\n",
    "'https://f000.backblazeb2.com/file/malay-dataset/text-similarity/quora/400k-500k.json']\n",
    "\n",
    "for url in urls:\n",
    "    # os.system returns the command's exit status; a non-zero value means wget failed,\n",
    "    # so stop here instead of silently carrying on with a missing shard.\n",
    "    status = os.system(f'wget {url}')\n",
    "    if status != 0:\n",
    "        raise RuntimeError(f'wget failed with status {status} for {url}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0-100k.json 99993\n",
      "400k-500k.json 4290\n",
      "300k-400k.json 99995\n",
      "200k-300k.json 99992\n",
      "100k-200k.json 100000\n"
     ]
    }
   ],
   "source": [
    "# Sort the shard names so the rows of the output file come out in a reproducible order.\n",
    "files = sorted(glob('*0*-*k.json'))\n",
    "# Integer duplicate flag -> Malay label written to the TSV.\n",
    "labels = {0: 'tak sama', 1: 'sama'}\n",
    "\n",
    "filename = 'quora.tsv'\n",
    "with tf.io.gfile.GFile(filename, 'w') as outfile:\n",
    "    for file in files:\n",
    "        with open(file) as fopen:\n",
    "            data = json.load(fopen)\n",
    "        print(file, len(data))\n",
    "\n",
    "        for row in data:\n",
    "            # A usable record is a ('question1 <> question2', flag) pair; skip anything else.\n",
    "            if len(row) != 2:\n",
    "                continue\n",
    "            # A flag outside `labels` raises KeyError rather than being written as-is.\n",
    "            label = labels[row[1]]\n",
    "            splitted = row[0].split(' <> ')\n",
    "            if len(splitted) != 2:\n",
    "                continue\n",
    "            q = f'soalan1: {cleaning(splitted[0])} soalan2: {cleaning(splitted[1])}'\n",
    "            outfile.write('%s\\t%s\\n' % (q, label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from t5.data import preprocessors as prep\n",
    "import functools\n",
    "import t5\n",
    "import gin\n",
    "\n",
    "# Parse the gin configuration file; the relative path means it is looked up in the\n",
    "# current working directory.\n",
    "gin.parse_config_file('pretrained_models_base_operative_config.gin')\n",
    "# File name of the SentencePiece model; it is passed as `sentencepiece_model_path`\n",
    "# when the 'similarity_dataset' task is registered.\n",
    "vocab = 'sp10m.cased.t5.model'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataset(split, shuffle_files = False):\n",
    "    \"\"\"Build a tf.data pipeline over a two-column, tab-separated text file.\n",
    "\n",
    "    `split` is used directly as the path of the file to read, and\n",
    "    `shuffle_files` is accepted but ignored. Each line is parsed into a dict\n",
    "    with the string fields 'question' and 'answer'.\n",
    "    \"\"\"\n",
    "    del shuffle_files\n",
    "\n",
    "    column_names = ['question', 'answer']\n",
    "    # Tab-delimited parser with quote handling switched off; both columns default to ''.\n",
    "    parse_line = functools.partial(\n",
    "        tf.io.decode_csv,\n",
    "        record_defaults = ['', ''],\n",
    "        field_delim = '\\t',\n",
    "        use_quote_delim = False,\n",
    "    )\n",
    "\n",
    "    def name_columns(*fields):\n",
    "        return dict(zip(column_names, fields))\n",
    "\n",
    "    lines = tf.data.TextLineDataset([split])\n",
    "    parsed = lines.map(parse_line, num_parallel_calls = tf.data.experimental.AUTOTUNE)\n",
    "    return parsed.map(name_columns)\n",
    "\n",
    "def preprocessor(ds):\n",
    "    \"\"\"Map each example's 'question' / 'answer' fields to 'inputs' / 'targets'.\"\"\"\n",
    "    rename = lambda ex: {'inputs': ex['question'], 'targets': ex['answer']}\n",
    "    return ds.map(rename, num_parallel_calls = tf.data.experimental.AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/t5/models/mesh_transformer.py:210: UserWarning: get_sentencepiece_model_path is deprecated. Please pass the mixture or task vocabulary directly to the Mesh TensorFlow Transformer instead.\n",
      "  \"get_sentencepiece_model_path is deprecated. Please pass the mixture or \"\n"
     ]
    }
   ],
   "source": [
    "# TSV files that the extraction loop in the next cell iterates over.\n",
    "files = ['quora.tsv', 'mnli.tsv', 'snli.tsv']\n",
    "\n",
    "# Options for the task: rows come from `dataset`, are reshaped by `preprocessor`,\n",
    "# tokenised with the SentencePiece model named by `vocab`, and scored with accuracy.\n",
    "task_options = dict(\n",
    "    dataset_fn = dataset,\n",
    "    splits = ['train'],\n",
    "    text_preprocessor = [preprocessor],\n",
    "    sentencepiece_model_path = vocab,\n",
    "    metric_fns = [t5.evaluation.metrics.accuracy],\n",
    ")\n",
    "\n",
    "# Remove the name from the registry first, then register the task under it.\n",
    "t5.data.TaskRegistry.remove('similarity_dataset')\n",
    "t5.data.TaskRegistry.add('similarity_dataset', **task_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "quora.tsv\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "403831it [03:05, 2178.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnli.tsv\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "387577it [02:56, 2189.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "snli.tsv\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "380288it [02:54, 2177.93it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "import json\n",
    "\n",
    "# The registry lookup does not depend on the file, so resolve the task once up front.\n",
    "task = t5.data.TaskRegistry.get(\"similarity_dataset\")\n",
    "\n",
    "for file in files:\n",
    "    print(file)\n",
    "    ds = task.get_dataset(split=file, sequence_length={\"inputs\": 1024, \"targets\": 1024})\n",
    "    # Turn each example's 'inputs' / 'targets' arrays into plain lists so json can serialise them.\n",
    "    results = [\n",
    "        (ex['inputs'].tolist(), ex['targets'].tolist())\n",
    "        for ex in tqdm(tfds.as_numpy(ds))\n",
    "    ]\n",
    "\n",
    "    # Output goes to the working directory as '<file name>.parse', e.g. 'quora.tsv.parse'.\n",
    "    with open(f'{os.path.basename(file)}.parse', 'w') as fopen:\n",
    "        json.dump(results, fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
